
Add LFM2.5-VL export with CUDA/AOTI backend #18823

Open
vincentzed wants to merge 1 commit into pytorch:main from vincentzed:vz-lfm2516b-squashed

Conversation


@vincentzed vincentzed commented Apr 10, 2026

Summary

Add LFM2.5-VL (450M and 1.6B) as a multi-method PTE with three methods — vision_encoder, token_embedding, and text_decoder — targeting CUDA/AOTI. LFM was not previously supported on CUDA. This PR originally targeted XNNPACK and was then reworked for CUDA; I have not yet had a chance to test the XNNPACK path. CI/unit tests for the pipeline probably still need to be added.

Context: On very small (<500M) models, llama.cpp and ExecuTorch both deliver good performance for low-latency use cases (i.e., versus SGLang, a higher-overhead framework, at concurrency=1). This is a first step toward a unified benchmark and rigorous measurement of that overhead.

HW: NVIDIA B300, torch 2.11, CUDA 13.0
Results: 333-400 decode tok/s, 435-454 prefill tok/s via the llama_main C++ runner.

Key implementation details

  • Conv-layer state support via attn_options["conv_states"] for AOTI compatibility. register_buffer is still used for the XNNPACK path, as before.
  • mark_static_address is applied to the state buffers (the same approach as transformers' StaticCache for Gemma3) so AOTI can trace them.
  • A manual depthwise conv (pointwise multiply+sum) replaces nn.Conv1d(groups=dim), since Triton has no template for depthwise conv1d with dynamic seq_len (or at least I could not get one working correctly). If there is an alternative, I would appreciate pointers to its implementation (I did not find one in the repo either).
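The manual depthwise path in the last bullet can be sketched per channel in plain Python (a stand-in for the actual tensor ops; names are illustrative): a causal kernel-3 convolution becomes three shifted multiply-adds, with a two-element state carrying the last inputs across decode steps.

```python
def short_conv_step(x, w, state=(0.0, 0.0)):
    """Causal depthwise conv (kernel size 3) for one channel as a
    pointwise multiply+sum, mirroring the manual replacement for
    nn.Conv1d(groups=dim). `state` holds the previous two inputs
    (the conv-state cache); returns outputs plus the new state."""
    padded = list(state) + list(x)
    out = [
        w[0] * padded[t] + w[1] * padded[t + 1] + w[2] * padded[t + 2]
        for t in range(len(x))
    ]
    return out, padded[-2:]
```

During prefill the whole prompt passes through at once; during decode, x is a single token and the two-element state supplies the causal context, which is why the state must live at a static address for AOTI.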

Prefill sweep (B300, bf16)

ISL Latency (ms) Throughput (tok/s)
32 8.0 4,002
128 15.8 8,105
512 19.0 26,974
1,024 21.0 48,758

Sample outputs (llama_main)

Prompt: "The capital of France is"
→ Paris.

Prompt: "List the planets in our solar system in order from the sun."
→ 1. Mercury 2. Venus 3. Earth 4. Mars 5. Jupiter 6. Saturn 7. Uranus 8. Neptune

Prompt: "Describe this image in detail." (glacier photo)
→ The image captures a breathtaking view of a majestic glacier, its icy blue surface
  glistening under the bright sunlight...

Test plan

=== Vision-Language ===

     Prompt: Describe this image in detail.
     Response: The image captures a breathtaking view of a majestic glacier, with its icy blue hue dominating the scene. The glacier stretches across the frame, its
     surface marked by intricate patterns and ridges that hint at the geological forces shaping it. In the background, a range of snow-capped mountains rises, their
     peaks contrasting sharply against a brilliant blue sky. The sky itself is a canvas of white clouds, some tinged with hints of orange, suggesting the time of day to
      be either dawn or dusk. The foreground features a serene body of water, its surface dotted with small ripples and bubbles, reflecting the tranquility of the
     Time: 12.13s

     Prompt: What objects do you see in this image?
     Response: In this stunning image, I see a breathtaking natural landscape featuring a large glacier in the foreground. The glacier is a striking blue color with
     white patches, creating a beautiful contrast against the surrounding environment.

     The scene is set against a backdrop of majestic mountains, which appear to be part of a larger mountain range. These mountains are covered in lush green
     vegetation, adding depth and richness to the overall composition.

     Above the mountains, the sky is a brilliant blue with scattered white clouds, creating a sense of vastness and openness.

     The water in the foreground is a vibrant turquoise color, reflecting the sunlight and
     Time: 11.66s

     === Text-Only ===

     Prompt: The capital of France is
     Response: Paris.
     Time: 0.28s

     Prompt: Explain the difference between a compiler and an interpreter in two sentences.
     Response: A compiler is a process that converts high-level programming language code into machine code, which can be executed by a computer. An interpreter, on the
      other hand, is a program that executes code line by line, without converting it to machine code.
     Time: 4.59s

     Prompt: What is the speed of light in meters per second?
     Response: The speed of light in a vacuum is approximately 299,792,458 meters per second. This value is derived from the fundamental laws of physics and is a
     constant that applies to all observers, regardless of their relative motion or location.
     Time: 4.36s

Export (multi-method PTE)

cd /path/to/workdir
python examples/models/lfm2_5_vl/export_lfm2_5_vl.py \
  --model_dir LiquidAI/LFM2.5-VL-450M --dtype bf16 \
  --output lfm2_5_vl_bf16_cuda.pte
# Produces: lfm2_5_vl_bf16_cuda.pte + aoti_cuda_blob.ptd

Run with llama_main (single-method PTE)

export_single_method.py
"""Export LFM2.5-VL-450M as single-method PTE compatible with llama_main."""

from __future__ import annotations

import logging
from pathlib import Path

import torch
from torch.export import Dim
from torch.export._trace import _export
from torch.nn.attention import SDPBackend

from executorch.backends.cuda.cuda_backend import CudaBackend
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.examples.models.lfm2.short_conv import ShortConvBlock
from executorch.examples.models.lfm2_5_vl.model import Lfm2p5VlModel, MAX_SEQ_LEN
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

try:
    from torch._inductor.codecache import cuda_compile_utils

    _orig_nvcc_arch = cuda_compile_utils._nvcc_arch_as_compile_option

    def _patched_nvcc_arch() -> str:
        return "103a" if cuda_compile_utils.cuda_env.get_cuda_arch() == "103" else _orig_nvcc_arch()

    cuda_compile_utils._nvcc_arch_as_compile_option = _patched_nvcc_arch
except (ImportError, AttributeError):
    pass

_PARAMS = Path(__file__).parent / ".." / "executorch" / "examples" / "models" / "lfm2_5_vl" / "config" / "lfm2_5_vl_450m_config.json"
_MODEL_DIR = Path("LFM2-VL-450M")
_OUTPUT = Path("lfm2_5_vl_llama_cuda.pte")


class _LlamaCompatModel(torch.nn.Module):
    """forward(input_ids, input_pos) -> logits, matching llama_main interface."""

    def __init__(
        self, lfm2: torch.nn.Module, conv_dim: int, conv_indices: list[int],
        *, dtype: torch.dtype, device: str,
    ) -> None:
        super().__init__()
        self.embed = lfm2.model_.model.language_model.get_input_embeddings()
        self.text_model = lfm2.text_model
        self.conv_indices = conv_indices

        for idx in conv_indices:
            buf = torch.zeros(1, conv_dim, 2, dtype=dtype, device=device)
            self.register_buffer(f"conv_state_{idx}", buf, persistent=False)
            if not torch.compiler.is_compiling():
                torch._dynamo.mark_static_address(buf)

    def forward(self, input_ids: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        embeddings = self.embed(input_ids)
        conv_states = {idx: getattr(self, f"conv_state_{idx}") for idx in self.conv_indices}
        out = self.text_model(None, {"input_pos": input_pos, "conv_states": conv_states}, embeddings)
        if isinstance(out, tuple):
            out = out[0]
        return out.contiguous()


def main() -> None:
    logging.info("Loading model...")
    lfm2_model = Lfm2p5VlModel(
        model_dir=str(_MODEL_DIR),
        max_seq_len=MAX_SEQ_LEN,
        max_context_len=MAX_SEQ_LEN,
        params_path=str(_PARAMS),
        use_sdpa_with_kv_cache_op=False,
    )
    lfm2 = lfm2_model.get_eager_model().to(dtype=torch.bfloat16, device="cuda")

    conv_indices = [i for i, layer in enumerate(lfm2.text_model.layers) if isinstance(layer, ShortConvBlock)]
    model = _LlamaCompatModel(lfm2, lfm2.text_model_args.dim, conv_indices, dtype=torch.bfloat16, device="cuda")

    # Mark KV cache buffers after device migration
    for module in model.text_model.modules():
        for name, buf in module.named_buffers(recurse=False):
            if name in ("k_cache", "v_cache"):
                torch._dynamo.mark_static_address(buf)

    seq = 8
    token_dim = Dim("token_dim", min=1, max=MAX_SEQ_LEN - 1)
    example_ids = torch.randint(1, 65000, (1, seq), dtype=torch.int64, device="cuda")
    example_pos = torch.arange(seq, dtype=torch.int64, device="cuda")

    logging.info("Exporting...")
    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
        ep = _export(
            model, (example_ids, example_pos),
            dynamic_shapes=({1: token_dim}, {0: token_dim}),
            strict=False,
            prefer_deferred_runtime_asserts_over_guards=True,
        )

    logging.info("Lowering to CUDA")
    compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
    et_prog = to_edge_transform_and_lower(
        {"forward": ep},
        partitioner={"forward": [CudaPartitioner(compile_specs)]},
        compile_config=EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True),
        constant_methods={"get_max_seq_len": MAX_SEQ_LEN, "get_vocab_size": lfm2.text_model_args.vocab_size, "use_kv_cache": True, "get_eos_ids": [7]},
    )

    et_program = et_prog.to_executorch(
        ExecutorchBackendConfig(
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
            sym_shape_eval_pass={"forward": ConstraintBasedSymShapeEvalPass()},
        )
    )

    logging.info("Saving %s", _OUTPUT)
    with open(_OUTPUT, "wb") as f:
        et_program.write_to_file(f)
    et_program.write_tensor_data_to_file(".")
    logging.info("Done")


if __name__ == "__main__":
    main()
# Build runner
make llama-cuda

# Export single-method
python export_single_method.py

# Run
cmake-out/examples/models/llama/llama_main \
  --model_path lfm2_5_vl_llama_cuda.pte \
  --data_paths aoti_cuda_blob.ptd \
  --tokenizer_path LFM2-VL-450M/tokenizer.json \
  --prompt $'<|startoftext|><|im_start|>user\nThe capital of France is<|im_end|>\n<|im_start|>assistant\n' \
  --max_new_tokens 64 --temperature 0.1 --warmup true

Python inference runner

run_lfm2vl.py
"""Run LFM2.5-VL-450M from an exported PTE+PTD on CUDA."""

from __future__ import annotations

import time
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor
from executorch.extension.pybindings.portable_lib import _load_for_executorch

PTE_PATH = Path("lfm2_5_vl_bf16_cuda.pte")
PTD_PATH = Path("aoti_cuda_blob.ptd")
MODEL_DIR = Path("LFM2-VL-450M")

IMAGE_TOKEN_ID = 396
EOS_ID = 7
VISION_INPUT_SIZE = 512

# Model card recommended sampling parameters
TEMPERATURE = 0.1
MIN_P = 0.15
REPETITION_PENALTY = 1.05


def _load_image_pixels(path: Path) -> torch.Tensor:
    """Load an image as [1, 3, 512, 512] NCHW float32 in [0, 255]."""
    img = Image.open(path).convert("RGB").resize((VISION_INPUT_SIZE, VISION_INPUT_SIZE))
    return torch.from_numpy(np.array(img)).permute(2, 0, 1).unsqueeze(0).float()


def _embed_tokens(module, input_ids: torch.Tensor) -> torch.Tensor:
    return module.run_method("token_embedding", [input_ids])[0].contiguous()


def _decode_step(module, embeddings: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    return module.run_method("text_decoder", [embeddings.contiguous(), input_pos])[0]


def _build_embeddings(
    module,
    input_ids: torch.Tensor,
    image_path: Path | None,
) -> torch.Tensor:
    """Build the full embedding sequence, splicing in vision embeddings if needed."""
    if image_path is None or IMAGE_TOKEN_ID not in input_ids[0]:
        return _embed_tokens(module, input_ids)

    positions = (input_ids[0] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
    first, last = positions[0].item(), positions[-1].item()

    before = _embed_tokens(module, input_ids[:, :first])
    after = _embed_tokens(module, input_ids[:, last + 1 :])

    pixels = _load_image_pixels(image_path).contiguous()
    image = module.run_method("vision_encoder", [pixels])[0].contiguous()

    return torch.cat([before, image, after], dim=1)


def _sample_token(
    logits: torch.Tensor,
    generated: list[int],
    temperature: float = TEMPERATURE,
    min_p: float = MIN_P,
    repetition_penalty: float = REPETITION_PENALTY,
) -> int:
    """Sample next token with temperature, min-p filtering, and repetition penalty."""
    scores = logits.float()

    # Repetition penalty: reduce logits for already-generated tokens
    if generated and repetition_penalty != 1.0:
        prev_tokens = torch.tensor(generated, dtype=torch.long, device=scores.device)
        token_scores = scores[prev_tokens]
        token_scores = torch.where(
            token_scores > 0,
            token_scores / repetition_penalty,
            token_scores * repetition_penalty,
        )
        scores[prev_tokens] = token_scores

    if temperature <= 0:
        return scores.argmax(dim=-1).item()

    probs = torch.softmax(scores / temperature, dim=-1)

    # Min-p filtering: zero out tokens below min_p * max_prob
    if min_p > 0:
        top_prob = probs.max()
        probs[probs < min_p * top_prob] = 0.0

    return torch.multinomial(probs, num_samples=1).item()


def generate(
    module,
    processor: AutoProcessor,
    prompt: str,
    *,
    image_path: Path | None = None,
    max_new_tokens: int = 128,
) -> str:
    content: list[dict[str, str]] = []
    if image_path is not None:
        content.append({"type": "image"})
    content.append({"type": "text", "text": prompt})

    text = processor.apply_chat_template(
        [{"role": "user", "content": content}], add_generation_prompt=True
    )
    input_ids = processor.tokenizer.encode(text, return_tensors="pt")

    embeddings = _build_embeddings(module, input_ids, image_path).contiguous()
    seq_len = embeddings.shape[1]
    logits = _decode_step(module, embeddings, torch.arange(seq_len, dtype=torch.int64))

    generated: list[int] = []
    cur_pos = seq_len

    for _ in range(max_new_tokens):
        last_logits = logits[:, -1, :].squeeze(0) if logits.dim() == 3 else logits.squeeze(0)
        token_id = _sample_token(last_logits, generated)
        if token_id == EOS_ID:
            break
        generated.append(token_id)

        token_embed = _embed_tokens(module, torch.tensor([[token_id]], dtype=torch.int64))
        logits = _decode_step(module, token_embed, torch.tensor([cur_pos], dtype=torch.int64))
        cur_pos += 1

    return processor.tokenizer.decode(generated, skip_special_tokens=True)


def main() -> None:
    module = _load_for_executorch(str(PTE_PATH), str(PTD_PATH))
    processor = AutoProcessor.from_pretrained(str(MODEL_DIR))

    test_image = Path("/tmp/test_image.jpg")

    sections: list[tuple[str, list[tuple[str, Path | None]]]] = [
        (
            "Vision-Language",
            [
                ("Describe this image in detail.", test_image),
                ("What objects do you see in this image?", test_image),
            ],
        ),
        (
            "Text-Only",
            [
                ("The capital of France is", None),
                ("Explain the difference between a compiler and an interpreter in two sentences.", None),
                ("What is the speed of light in meters per second?", None),
            ],
        ),
    ]

    for section_name, prompts in sections:
        print(f"\n=== {section_name} ===")
        for prompt, img in prompts:
            print(f"\nPrompt: {prompt}")
            t0 = time.perf_counter()
            response = generate(module, processor, prompt, image_path=img)
            elapsed = time.perf_counter() - t0
            print(f"Response: {response}")
            print(f"Time: {elapsed:.2f}s")


if __name__ == "__main__":
    main()

Benchmark

bench_lfm2vl.py
"""Benchmark LFM2.5-VL-450M on ExecuTorch CUDA — matches llama_main metrics."""

from __future__ import annotations

import time
from pathlib import Path

import torch
from transformers import AutoProcessor
from executorch.extension.pybindings.portable_lib import _load_for_executorch

PTE_PATH = Path("lfm2_5_vl_bf16_cuda.pte")
PTD_PATH = Path("aoti_cuda_blob.ptd")
MODEL_DIR = Path("LFM2-VL-450M")

EOS_ID = 7
IMAGE_TOKEN_ID = 396

TEMPERATURE = 0.1
TOP_P = 0.9


def _embed(module, ids: torch.Tensor) -> torch.Tensor:
    return module.run_method("token_embedding", [ids])[0].contiguous()


def _decode(module, emb: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    return module.run_method("text_decoder", [emb.contiguous(), pos])[0]


def _sample(logits: torch.Tensor) -> int:
    if TEMPERATURE <= 0:
        return torch.argmax(logits, dim=-1).item()
    probs = torch.softmax(logits / TEMPERATURE, dim=-1)
    probs_sort, probs_idx = torch.sort(probs, descending=True)
    cum = torch.cumsum(probs_sort, dim=-1)
    mask = (cum - probs_sort) > TOP_P
    probs_sort[mask] = 0.0
    probs_sort /= probs_sort.sum()
    return torch.gather(probs_idx, -1, torch.multinomial(probs_sort, 1)).item()


def benchmark_text(
    module, tokenizer, prompt: str, *, max_new_tokens: int = 128, warmup: bool = True
) -> None:
    tokens = tokenizer.encode(prompt)
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.squeeze().tolist()
    prompt_len = len(tokens)

    if warmup:
        ids = torch.tensor([tokens], dtype=torch.int64)
        emb = _embed(module, ids)
        pos = torch.arange(emb.shape[1], dtype=torch.int64)
        _decode(module, emb, pos)

    # --- Prefill ---
    ids = torch.tensor([tokens], dtype=torch.int64)
    torch.cuda.synchronize()
    t_prefill = time.perf_counter()
    emb = _embed(module, ids)
    pos = torch.arange(emb.shape[1], dtype=torch.int64)
    logits = _decode(module, emb, pos)
    torch.cuda.synchronize()
    prefill_time = time.perf_counter() - t_prefill

    last = logits[:, -1, :].squeeze(0) if logits.dim() == 3 else logits.squeeze(0)
    cur_token = _sample(last)
    generated = [cur_token]
    cur_pos = prompt_len

    # --- Decode ---
    torch.cuda.synchronize()
    t_decode = time.perf_counter()
    while len(generated) < max_new_tokens:
        tok_emb = _embed(module, torch.tensor([[cur_token]], dtype=torch.int64))
        logits = _decode(module, tok_emb, torch.tensor([cur_pos], dtype=torch.int64))
        last = logits[:, -1, :].squeeze(0) if logits.dim() == 3 else logits.squeeze(0)
        cur_token = _sample(last)
        if cur_token == EOS_ID:
            break
        generated.append(cur_token)
        cur_pos += 1
    torch.cuda.synchronize()
    decode_time = time.perf_counter() - t_decode

    decode_tokens = len(generated) - 1  # first token came from prefill
    text = tokenizer.decode(generated, skip_special_tokens=True)

    print(f"  Prompt ({prompt_len} tokens): {prompt[:60]}...")
    print(f"  Output ({len(generated)} tokens): {text[:80]}...")
    print(f"  Prefill:  {prefill_time*1000:.1f} ms  |  {prompt_len/prefill_time:.0f} tok/s")
    print(f"  TTFT:     {prefill_time*1000:.1f} ms")
    if decode_tokens > 0:
        print(f"  Decode:   {decode_time*1000:.1f} ms  |  {decode_tokens/decode_time:.1f} tok/s")
    print()


def benchmark_prefill_sweep(module, tokenizer) -> None:
    """Prefill-only benchmark across different input lengths."""
    print("=== Prefill Sweep ===")
    print(f"{'ISL':>6}  {'Latency (ms)':>12}  {'Throughput (tok/s)':>18}")
    print("-" * 42)

    for isl in [32, 64, 128, 256, 512, 1024]:
        ids = torch.randint(10, 65000, (1, isl), dtype=torch.int64)

        # Warmup
        emb = _embed(module, ids)
        pos = torch.arange(isl, dtype=torch.int64)
        _decode(module, emb, pos)

        # Timed (5 runs, take median)
        times = []
        for _ in range(5):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            emb = _embed(module, ids)
            pos = torch.arange(isl, dtype=torch.int64)
            _decode(module, emb, pos)
            torch.cuda.synchronize()
            times.append(time.perf_counter() - t0)

        times.sort()
        median = times[len(times) // 2]
        print(f"{isl:>6}  {median*1000:>12.2f}  {isl/median:>18.0f}")

    print()


def main() -> None:
    print(f"Loading {PTE_PATH} + {PTD_PATH}\n")
    module = _load_for_executorch(str(PTE_PATH), str(PTD_PATH))
    processor = AutoProcessor.from_pretrained(str(MODEL_DIR))
    tokenizer = processor.tokenizer

    # --- Text generation benchmarks ---
    prompts = [
        "The capital of France is",
        "Explain the difference between a compiler and an interpreter in two sentences.",
        "Write a short paragraph about the history of artificial intelligence.",
        "What is the speed of light in meters per second? Give just the number.",
    ]

    print("=== Text Generation (warmup + timed) ===\n")
    for prompt in prompts:
        benchmark_text(module, tokenizer, prompt, max_new_tokens=64)

    # --- Prefill sweep ---
    benchmark_prefill_sweep(module, tokenizer)


if __name__ == "__main__":
    main()
python bench_lfm2vl.py

Verification status

  • CUDA/AOTI export (multi-method: vision_encoder + token_embedding + text_decoder)
  • CUDA/AOTI export (single-method: forward, for llama_main)
  • Text-only generation quality (Paris, compiler/interpreter, speed of light, planets)
  • Vision-language generation quality (glacier image: coherent multi-sentence descriptions)
  • llama_main C++ runner (333-400 decode tok/s)
  • Python pybindings runner
  • Prefill sweep benchmark (ISL 32-1024)
  • XNNPack LFM2 text-only export still works (short_conv.py has dual state path but untested)
  • XNNPack LFM2.5-VL export (vision + XNNPack text decoder)
  • CI tests

Known limitations / future work

  • Blackwell sm_103 arch workaround: monkey-patches a torch._inductor private API to fix an nvcc/Triton PTX mismatch. Fragile; should be fixed upstream in PyTorch (the relevant code, _nvcc_arch_as_compile_option, maps 103→100f; it should map 103→103a).
  • Emitter CUDA storage fix: exir/emit/_emitter.py copies CUDA tensor storage to CPU before ctypes.data_ptr() read. This is a general fix, not LFM2.5-VL-specific — should be upstreamed as a standalone PR.
  • Conv state dynamic getattr pattern: _Decoder and _LlamaCompatModel use register_buffer(f"conv_state_{idx}") + getattr(self, f"conv_state_{idx}"). Works but violates the "no dynamic setattr/getattr" style guideline. Could use a list-based approach instead.
  • Batched export: batch_size>1 works for export but KV cache is pre-allocated at max_batch_size, consuming significant memory. batch=2048 OOMs during AOTI autotuning.
  • Vision encoder fixed to 512×512: the exported vision encoder bakes in normalization and patchification for a single 512×512 image. Multi-image / variable-resolution / tiling (as described in the model card) is not supported.
  • No llava_main integration: the multi-method PTE (vision_encoder + token_embedding + text_decoder) follows the LLaVA runner pattern but hasn't been tested with the actual llava_main C++ binary.
  • lm_head weight tying assumption: convert_weights.py assumes lm_head.weight == tok_embeddings.weight (tied embeddings). If a future checkpoint untied them, the lm_head weights would be silently ignored.
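The weight-tying assumption in the last bullet could be guarded with a small check in the converter (hypothetical helper; plain-Python stand-in for the tensor comparison):

```python
def check_tied_lm_head(state_dict):
    """Return True when lm_head.weight is absent or identical to
    tok_embeddings.weight; raise if an untied lm_head would be
    silently dropped by the weight converter."""
    lm_head = state_dict.get("lm_head.weight")
    if lm_head is None:
        return True  # no separate head: tying is implicit
    if lm_head == state_dict["tok_embeddings.weight"]:
        return True
    raise ValueError("lm_head.weight is untied; converter would drop it")
```

Calling this once during conversion turns the silent-ignore failure mode into a loud error on a future untied checkpoint.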

cc @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng


pytorch-bot bot commented Apr 10, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18823

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 1 Cancelled Job, 2 Unrelated Failures

As of commit baf48bb with merge base 273aee9:

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.


meta-cla bot commented Apr 10, 2026

Hi @vincentzed!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


meta-cla bot commented Apr 10, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 10, 2026
@vincentzed vincentzed marked this pull request as ready for review April 10, 2026 21:39
Copilot AI review requested due to automatic review settings April 10, 2026 21:39
@vincentzed vincentzed force-pushed the vz-lfm2516b-squashed branch from e61e728 to baf48bb Compare April 10, 2026 21:39
@vincentzed
Author

Hello @Gasoonjia. I realize there is no CC list. Could you help review this, or point me to the right person? Thanks in advance!

Contributor

Copilot AI left a comment


Pull request overview

Adds an ExecuTorch export path for LiquidAI’s LFM2.5-VL models targeting the CUDA/AOTI backend, including a new multi-method PTE export pipeline and the required model/runtime adaptations (notably conv-state handling for hybrid conv/attention layers).

Changes:

  • Introduces examples/models/lfm2_5_vl/ (model wrapper, HF->ET weight remap, export script, and configs) to export vision_encoder, token_embedding, and text_decoder methods.
  • Updates LFM2 short-conv blocks to support “state-as-IO” via attn_options["conv_states"] and adds layer_idx wiring from the transformer constructor.
  • Fixes constant serialization in the emitter by copying non-CPU storages to CPU before reading bytes.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
exir/emit/_emitter.py Ensures constant tensor storage is moved to CPU before byte-serialization when storage is on a non-CPU device.
examples/models/llama/llama_transformer.py Passes layer_idx into ShortConvBlock for per-layer conv state mapping.
examples/models/lfm2/short_conv.py Refactors short conv to support explicit conv-state IO for AOTI and implements a manual depthwise conv path.
examples/models/lfm2_5_vl/model.py Adds an ExecuTorch-friendly LFM2.5-VL wrapper integrating HF vision tower + ET text transformer.
examples/models/lfm2_5_vl/export_lfm2_5_vl.py Adds CUDA/AOTI multi-method export pipeline producing PTE + PTD blob.
examples/models/lfm2_5_vl/convert_weights.py Adds a HF->ET key remapping utility for text-decoder weights.
examples/models/lfm2_5_vl/config/lfm2_5_vl_450m_config.json ModelArgs config for 450M variant.
examples/models/lfm2_5_vl/config/lfm2_5_vl_1_6b_config.json ModelArgs config for 1.6B variant.
examples/models/lfm2_5_vl/init.py Exposes Lfm2p5VlModel and convert_weights from the new package.


Comment on lines +88 to +93
if attn_options is not None and "conv_states" in attn_options:
if conv_state is not None:
conv_state.copy_(new_conv_state)
states = dict(attn_options["conv_states"])
states[self.layer_idx] = new_conv_state
update["conv_states"] = states

Copilot AI Apr 10, 2026


When attn_options contains conv_states, this block mutates the provided state via conv_state.copy_(...) but then stores new_conv_state (a freshly allocated tensor from cat) back into the returned conv_states dict. In Transformer._forward_layers, that returned dict is merged into attn_options_, so the next layer call will read a non-static tensor and can break the intended AOTI "static address" state path. Also, dict(attn_options["conv_states"]) will throw if the key exists but the value is None. Consider: (1) reading conv_states = attn_options.get("conv_states") and ensuring it’s a dict before copying, and (2) if conv_state is provided, keep that same tensor in the returned mapping (after the in-place update) rather than replacing it with new_conv_state.
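The handling this comment suggests could look roughly like the following (plain-Python stand-in; the helper name and dict shapes are illustrative, not the PR's actual code): validate that conv_states is a dict, update the existing static-address state in place when present, and keep that same object in the returned mapping instead of the freshly allocated tensor.

```python
def update_conv_states(attn_options, layer_idx, new_conv_state):
    """Merge one layer's new conv state into the conv_states mapping,
    keeping the original (static-address) object when it exists."""
    states = attn_options.get("conv_states") if attn_options else None
    if not isinstance(states, dict):
        return {}  # key missing or None: nothing to propagate
    existing = states.get(layer_idx)
    if existing is not None:
        existing[:] = new_conv_state  # in-place, stand-in for copy_()
        return {"conv_states": dict(states)}  # same object retained
    merged = dict(states)
    merged[layer_idx] = new_conv_state
    return {"conv_states": merged}
```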

Comment on lines +50 to +55
# Manual depthwise conv — Triton has no template for nn.Conv1d
# with groups=dim and dynamic sequence length.
w = self.conv.weight[:, 0, :]
conv_out = Bx[..., :-2] * w[:, 0:1] + Bx[..., 1:-1] * w[:, 1:2] + Bx[..., 2:] * w[:, 2:3]

def reset_cache(self):
self.conv_state.zero_()
y = self.out_proj((C * conv_out).transpose(-1, -2).contiguous())

Copilot AI Apr 10, 2026


ShortConv.forward implements the convolution manually using self.conv.weight, but it ignores self.conv.bias when bias=True. This makes the bias argument silently incorrect. Either add the bias term to conv_out or enforce bias=False (e.g., via an assertion and/or by removing the parameter) to avoid surprising behavior.

Comment on lines +56 to +60
orig = embeddings.position_embedding.weight.data
sqrt_n = int(math.sqrt(orig.shape[0]))

grid = orig.reshape(sqrt_n, sqrt_n, -1).permute(2, 0, 1).unsqueeze(0)
resized = F.interpolate(

Copilot AI Apr 10, 2026


sqrt_n = int(math.sqrt(orig.shape[0])) is used to reshape positional embeddings into a square grid, but this will silently truncate when orig.shape[0] is not a perfect square and then fail or mis-reshape. It would be safer to assert sqrt_n * sqrt_n == orig.shape[0] (or handle the non-square case explicitly) before reshape.
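A guard along the lines this comment suggests (hypothetical helper name) would fail loudly instead of truncating:

```python
import math

def square_grid_side(num_positions):
    """Return the side length of the positional-embedding grid,
    raising when num_positions is not a perfect square rather than
    silently truncating before the reshape."""
    side = math.isqrt(num_positions)
    if side * side != num_positions:
        raise ValueError(
            f"{num_positions} position embeddings do not form a square grid"
        )
    return side
```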

Comment on lines +78 to +84
```python
def image_embedding(self, nchw_pixels: torch.Tensor) -> torch.Tensor:
    """[B, 3, 512, 512] float32 pixels in [0, 255] -> [B, 256, D]."""
    x = (nchw_pixels / 255.0 - 0.5) / 0.5

    x = x.unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
    x = x.permute(0, 2, 3, 4, 5, 1).reshape(1, FIXED_H * FIXED_W, PATCH_SIZE * PATCH_SIZE * 3)
```

Copilot AI Apr 10, 2026
image_embedding hard-codes batch size 1 via .reshape(1, ...) and later returns projected.reshape(1, ...), but the docstring and type hints imply it supports [B, ...]. If batch size is intentionally fixed to 1, consider asserting nchw_pixels.shape[0] == 1 and updating the docstring; otherwise, preserve B through the reshapes so the method behaves correctly for B>1.
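A sketch of the B-preserving variant. The constants here (`PATCH_SIZE=32`, a 16x16 patch grid) are inferred from the `[B, 256, D]` docstring and 512-pixel input, and may differ from the PR's actual values:

```python
import torch

# Preserve batch size B through the unfold/reshape instead of hard-coding 1.
PATCH_SIZE, FIXED_H, FIXED_W = 32, 16, 16  # assumed: 512 / 32 = 16 patches per side
pixels = torch.randn(2, 3, 512, 512)       # B = 2 instead of the hard-coded 1
B = pixels.shape[0]
x = pixels.unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
x = x.permute(0, 2, 3, 4, 5, 1).reshape(B, FIXED_H * FIXED_W, PATCH_SIZE * PATCH_SIZE * 3)
```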

Comment on lines +145 to +152
```python
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    return torch.export._trace._export(
        _Decoder(lfm2.text_model, dim, conv_indices, dtype=dtype, device=device),
        (example_emb, example_pos),
        dynamic_shapes=({1: token_dim}, {0: token_dim}),
        strict=False,
        prefer_deferred_runtime_asserts_over_guards=True,
    )
```
Copilot AI Apr 10, 2026
This uses the private API torch.export._trace._export, which is not stable and may break across PyTorch versions. If possible, prefer the public torch.export.export(...) API; otherwise, consider isolating this behind a small helper with a clear comment/version guard so failures are easier to diagnose when PyTorch internals change.

@nil-is-all nil-is-all added module: llm Issues related to LLM examples and apps, and to the extensions/llm/ code module: cuda Issues related to the AOTI CUDA backend labels Apr 14, 2026
"""Depthwise short convolution with dual state management.

Supports two modes:
1. State-as-IO: caller passes conv_state in and receives new state back.
Contributor
It should be able to handle mutable state; it does with the regular KV cache in MHA.

Author

Yes, you are correct; let me fix this.
Initially, with AOTI I hit PendingUnbackedSymbolNotFound. My mistake was using register_buffer("conv_state") + copy_() alone, which is why I switched to attn_options["conv_states"]. But adding mark_static_address makes the buffer approach work (the same mechanism the KV cache uses).

That is:
This code SHOULD use register_buffer("conv_state") + mark_static_address(conv_state) + conv_state.copy_() (I will fix this). AFAIK there is no notable performance difference at the moment.

Comment on lines +63 to +65
```python
def _patched_nvcc_arch() -> str:
    arch = cuda_compile_utils.cuda_env.get_cuda_arch()
    return "103a" if arch == "103" else _orig_nvcc_arch()
```
Contributor

@vincentzed (Author) Apr 14, 2026

It may be related to my using an SGLang Docker image instead (I realize this is improper); without it I believe it would work fine. I am on torch 2.11.

Edit:
On B300, torch._inductor.codecache.cuda_compile_utils._nvcc_arch_as_compile_option returns 100f, so nvcc chooses -gencode arch=compute_100f,code=sm_100f. But the Triton PTX has .target sm_103a, which clashes (though they should technically be interchangeable): the SM version specified by .target is higher than the default SM version assumed.

Or please see this comment:
pytorch/pytorch#175951 (comment)

```python
)

storage = typing.cast(torch.UntypedStorage, spec.storage)
if spec.allocated_memory != 0 and storage.device.type != "cpu":
```
Contributor

Why are non-CPU tensors making it into the emitter? Is this related to the state getting passed around as IO?

Author

I believe it's unrelated to the state-as-IO path (I may be slightly off).
The model is exported on CUDA: AOTI compiles CUDA kernels from the graph, and mark_static_address operates on the KV-cache buffers, and both require the model to be on CUDA at export time.

The emitter in the PTE step then does ctypes.cast(storage.data_ptr(), ...), which segfaults on GPU memory.

I am not sure whether there is a better fix for this.
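A sketch of the emitter-side fix described in the commit message, under the assumption the serializer reads raw storage bytes via ctypes (`storage_bytes` is a hypothetical helper, not the PR's actual code):

```python
import ctypes
import torch

def storage_bytes(storage: torch.UntypedStorage) -> bytes:
    """Copy CUDA storage to host memory before the ctypes pointer read."""
    if storage.device.type != "cpu":
        storage = storage.cpu()  # host copy with a host-readable data_ptr
    return bytes((ctypes.c_ubyte * storage.nbytes()).from_address(storage.data_ptr()))

t = torch.arange(4, dtype=torch.uint8)
data = storage_bytes(t.untyped_storage())
```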

Export LFM2.5-VL (450M and 1.6B) as a multi-method PTE with three
methods (vision_encoder, token_embedding, text_decoder), all delegated
to the CUDA/AOTI backend.

New files under examples/models/lfm2_5_vl/: model, weight converter,
export script, and config JSONs.

Modifications to existing files are kept minimal:
- examples/models/lfm2/short_conv.py: replace nn.Conv1d(groups=dim) call
  with manual pointwise multiply+sum. Triton has no template for
  depthwise conv1d with dynamic sequence length. Mutable buffer state
  for conv_state is unchanged — AOTI handles it via mark_static_address
  at export time, same mechanism as the KV cache in MHA.
- exir/emit/_emitter.py: copy CUDA tensor storage to CPU before ctypes
  pointer read during constant serialization. Prevents segfault when
  exporting a model whose parameters live on CUDA.

Tested on NVIDIA B300 (CUDA 13.0, torch 2.11): 333-400 decode tok/s,
435-454 prefill tok/s, coherent generation on text-only and vision-
language prompts via both the Python pybindings and the llama_main
C++ runner.
@vincentzed vincentzed force-pushed the vz-lfm2516b-squashed branch from baf48bb to 3544f0b Compare April 14, 2026 20:31